/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    GeneralRegression.java
 *    Copyright (C) 2008-2012 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.pmml.consumer;

import java.io.Serializable;
import java.util.ArrayList;

import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

import weka.core.Attribute;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Utils;
import weka.core.pmml.MiningSchema;
import weka.core.pmml.PMMLUtils;
import weka.core.pmml.TargetMetaInfo;

/**
 * Class implementing import of PMML General Regression model. Can be used as a
 * Weka classifier for prediction (buildClassifier() raises an Exception).
 *
 * @author Mark Hall (mhall{[at]}pentaho{[dot]}com)
 * @version $Revision$
 */
public class GeneralRegression extends PMMLClassifier implements Serializable {

    /**
     * For serialization
     */
    private static final long serialVersionUID = 2583880411828388959L;

    /**
     * The kinds of general regression model. Each constant carries the exact
     * string that identifies it in the <code>modelType</code> attribute of the
     * PMML model element; that string is what {@link #toString()} returns.
     */
    enum ModelType {

        REGRESSION("regression"),
        GENERALLINEAR("generalLinear"),
        MULTINOMIALLOGISTIC("multinomialLogistic"),
        ORDINALMULTINOMIAL("ordinalMultinomial"),
        GENERALIZEDLINEAR("generalizedLinear");

        /** Textual form of this model type. */
        private final String m_stringVal;

        /**
         * Constructor.
         *
         * @param pmmlName the textual name of this model type
         */
        ModelType(String pmmlName) {
            this.m_stringVal = pmmlName;
        }

        /**
         * Returns the textual name of this model type rather than the
         * identifier of the constant.
         *
         * @return the textual name
         */
        @Override
        public String toString() {
            return this.m_stringVal;
        }
    }

    /** The model type; regression unless the constructor reads another value. */
    protected ModelType m_modelType = ModelType.REGRESSION;

    /** The model name (if defined); null when the PMML does not supply one. */
    protected String m_modelName;

    /** The algorithm name (if defined); null when the PMML does not supply one. */
    protected String m_algorithmName;

    /**
     * The function type (regression or classification), held as one of the
     * int constants of Regression.RegressionTable.
     */
    protected int m_functionType = Regression.RegressionTable.REGRESSION;

    /**
     * The cumulative link functions available to the ordinal multinomial model
     * type. Every constant maps the sum of a raw response value and an offset
     * to the result of its link function; {@link #toString()} yields the
     * textual name of the constant.
     */
    enum CumulativeLinkFunction {
        NONE("none") {
            @Override
            double eval(double value, double offset) {
                // no evaluation is defined when there is no link function
                return Double.NaN;
            }
        },
        LOGIT("logit") {
            @Override
            double eval(double value, double offset) {
                final double eta = value + offset;
                return 1.0 / (1.0 + Math.exp(-eta));
            }
        },
        PROBIT("probit") {
            @Override
            double eval(double value, double offset) {
                final double eta = value + offset;
                return weka.core.matrix.Maths.pnorm(eta);
            }
        },
        CLOGLOG("cloglog") {
            @Override
            double eval(double value, double offset) {
                final double eta = value + offset;
                return 1.0 - Math.exp(-Math.exp(eta));
            }
        },
        LOGLOG("loglog") {
            @Override
            double eval(double value, double offset) {
                final double eta = value + offset;
                return Math.exp(-Math.exp(-eta));
            }
        },
        CAUCHIT("cauchit") {
            @Override
            double eval(double value, double offset) {
                final double eta = value + offset;
                return 0.5 + (1.0 / Math.PI) * Math.atan(eta);
            }
        };

        /** Textual name of this link function. */
        private final String m_stringVal;

        /**
         * Constructor.
         *
         * @param linkName textual name for this enum
         */
        CumulativeLinkFunction(String linkName) {
            this.m_stringVal = linkName;
        }

        /**
         * Applies this link function.
         *
         * @param value  the raw response value
         * @param offset the offset to add to the raw value
         * @return the result of the link function (NaN for NONE)
         */
        abstract double eval(double value, double offset);

        /**
         * Returns the textual name of this link function.
         *
         * @return the textual name
         * @see java.lang.Enum#toString()
         */
        @Override
        public String toString() {
            return this.m_stringVal;
        }
    }

    /**
     * Cumulative link function (ordinal multinomial only); stays NONE for the
     * other model types.
     */
    protected CumulativeLinkFunction m_cumulativeLinkFunction = CumulativeLinkFunction.NONE;

    /**
     * The link functions available to the general linear and generalized
     * linear model types. Every constant turns the sum of a raw response value
     * and an offset into a predicted value, which is then scaled by the number
     * of trials; {@link #toString()} yields the textual name of the constant.
     */
    enum LinkFunction {
        NONE("none") {
            @Override
            double eval(double value, double offset, double trials, double distParam, double linkParam) {
                // no evaluation is defined when there is no link function
                return Double.NaN;
            }
        },
        CLOGLOG("cloglog") {
            @Override
            double eval(double value, double offset, double trials, double distParam, double linkParam) {
                final double eta = value + offset;
                return (1.0 - Math.exp(-Math.exp(eta))) * trials;
            }
        },
        IDENTITY("identity") {
            @Override
            double eval(double value, double offset, double trials, double distParam, double linkParam) {
                final double eta = value + offset;
                return eta * trials;
            }
        },
        LOG("log") {
            @Override
            double eval(double value, double offset, double trials, double distParam, double linkParam) {
                final double eta = value + offset;
                return Math.exp(eta) * trials;
            }
        },
        LOGC("logc") {
            @Override
            double eval(double value, double offset, double trials, double distParam, double linkParam) {
                final double eta = value + offset;
                return (1.0 - Math.exp(eta)) * trials;
            }
        },
        LOGIT("logit") {
            @Override
            double eval(double value, double offset, double trials, double distParam, double linkParam) {
                final double eta = value + offset;
                return (1.0 / (1.0 + Math.exp(-eta))) * trials;
            }
        },
        LOGLOG("loglog") {
            @Override
            double eval(double value, double offset, double trials, double distParam, double linkParam) {
                final double eta = value + offset;
                return Math.exp(-Math.exp(-eta)) * trials;
            }
        },
        NEGBIN("negbin") {
            @Override
            double eval(double value, double offset, double trials, double distParam, double linkParam) {
                final double eta = value + offset;
                return (1.0 / (distParam * (Math.exp(-eta) - 1.0))) * trials;
            }
        },
        ODDSPOWER("oddspower") {
            @Override
            double eval(double value, double offset, double trials, double distParam, double linkParam) {
                final double eta = value + offset;
                // Written as two strict comparisons on purpose: a NaN link
                // parameter fails both and therefore takes the logit branch.
                if (linkParam < 0.0 || linkParam > 0.0) {
                    return (1.0 / (1.0 + Math.pow(1.0 + linkParam * eta, (-1.0 / linkParam)))) * trials;
                }
                return (1.0 / (1.0 + Math.exp(-eta))) * trials;
            }
        },
        POWER("power") {
            @Override
            double eval(double value, double offset, double trials, double distParam, double linkParam) {
                final double eta = value + offset;
                // Written as two strict comparisons on purpose: a NaN link
                // parameter fails both and therefore takes the exp branch.
                if (linkParam < 0.0 || linkParam > 0.0) {
                    return Math.pow(eta, (1.0 / linkParam)) * trials;
                }
                return Math.exp(eta) * trials;
            }
        },
        PROBIT("probit") {
            @Override
            double eval(double value, double offset, double trials, double distParam, double linkParam) {
                final double eta = value + offset;
                return weka.core.matrix.Maths.pnorm(eta) * trials;
            }
        };

        /** Textual name of this link function. */
        private final String m_stringVal;

        /**
         * Constructor.
         *
         * @param linkName the textual name of this link function
         */
        LinkFunction(String linkName) {
            this.m_stringVal = linkName;
        }

        /**
         * Applies this link function.
         *
         * @param value     the raw response value
         * @param offset    the offset to add to the raw value
         * @param trials    the trials value to multiply the result by
         * @param distParam the distribution parameter (negbin only)
         * @param linkParam the link parameter (power and oddspower only)
         * @return the result of the link function (NaN for NONE)
         */
        abstract double eval(double value, double offset, double trials, double distParam, double linkParam);

        /**
         * Returns the textual name of this link function.
         *
         * @return the textual name
         * @see java.lang.Enum#toString()
         */
        @Override
        public String toString() {
            return this.m_stringVal;
        }
    }

    /** Link function (generalLinear and generalizedLinear model types only). */
    protected LinkFunction m_linkFunction = LinkFunction.NONE;
    /** Link parameter (used by the power and oddspower links); NaN if not specified. */
    protected double m_linkParameter = Double.NaN;
    /** Name given by the trialsVariable attribute; null if not specified. */
    protected String m_trialsVariable;
    /** Value given by the trialsValue attribute; NaN if not specified. */
    protected double m_trialsValue = Double.NaN;

    /**
     * The distributions available to the general linear and generalized
     * linear model types. {@link #toString()} yields the textual name of each
     * constant, which is not always the lower-cased identifier (for example
     * INVGAUSSIAN is "igauss").
     */
    enum Distribution {
        NONE("none"),
        NORMAL("normal"),
        BINOMIAL("binomial"),
        GAMMA("gamma"),
        INVGAUSSIAN("igauss"),
        NEGBINOMIAL("negbin"),
        POISSON("poisson");

        /** Textual name of this distribution. */
        private final String m_stringVal;

        /**
         * Constructor.
         *
         * @param distributionName the textual name of this distribution
         */
        Distribution(String distributionName) {
            this.m_stringVal = distributionName;
        }

        /**
         * Returns the textual name of this distribution.
         *
         * @return the textual name
         * @see java.lang.Enum#toString()
         */
        @Override
        public String toString() {
            return this.m_stringVal;
        }
    }

    /** Distribution (generalLinear and generalizedLinear model type only). */
    protected Distribution m_distribution = Distribution.NORMAL;

    /**
     * Ancillary parameter value for the negative binomial distribution; NaN
     * if not specified.
     */
    protected double m_distParameter = Double.NaN;

    /**
     * If present, this variable is used during scoring
     * generalizedLinear/generalLinear or ordinalMultinomial models. Null if
     * not specified.
     */
    protected String m_offsetVariable;

    /**
     * If present, this value is used during scoring
     * generalizedLinear/generalLinear or ordinalMultinomial models. It works
     * like a user-specified intercept. At most, only one of offsetVariable or
     * offsetValue may be specified. NaN if not specified.
     */
    protected double m_offsetValue = Double.NaN;

    /**
     * Simple serializable holder that pairs the name of a parameter with its
     * optional descriptive label.
     */
    static class Parameter implements Serializable {
        // ESCA-JAVA0096:
        /** For serialization */
        // CHECK ME WITH serialver
        private static final long serialVersionUID = 6502780192411755341L;

        /** The name of the parameter. */
        protected String m_name;

        /** The descriptive label; null until one is assigned. */
        protected String m_label;
    }

    /** List of model parameters, in the order they are added while reading the model. */
    protected ArrayList<Parameter> m_parameterList = new ArrayList<Parameter>();

    /**
     * Simple serializable holder for a factor or covariate: its name together
     * with the index of the matching attribute in the mining schema.
     */
    static class Predictor implements Serializable {
        /** For serialization */
        // CHECK ME WITH serialver
        private static final long serialVersionUID = 6502780192411755341L;

        /** The name of the factor or covariate. */
        protected String m_name;

        /** Index of the corresponding mining schema attribute; -1 until set. */
        protected int m_miningSchemaIndex = -1;

        /**
         * Returns the name of this predictor.
         *
         * @return the predictor name (may be null if it was never set)
         */
        @Override
        public String toString() {
            return this.m_name;
        }
    }

    /** FactorList: the predictors read from the FactorList element. */
    protected ArrayList<Predictor> m_factorList = new ArrayList<Predictor>();

    /** CovariateList: the predictors read from the CovariateList element. */
    protected ArrayList<Predictor> m_covariateList = new ArrayList<Predictor>();

    /**
     * Simple serializable holder for one predictor-to-parameter correlation
     * (a single cell of the PPMatrix).
     */
    static class PPCell implements Serializable {
        /** For serialization */
        // CHECK ME WITH serialver
        private static final long serialVersionUID = 6502780192411755341L;

        /** Name of the predictor this cell refers to. */
        protected String m_predictorName;

        /** Name of the parameter this cell refers to. */
        protected String m_parameterName;

        /**
         * Either the exponent of a numeric attribute or the index of a
         * discrete value.
         */
        protected double m_value;

        /**
         * Optional. The default is for all target categories to share the
         * same PPMatrix.
         */
        // TO-DO: implement multiple PPMatrixes
        protected String m_targetCategory;

    }

    /**
     * PPMatrix (predictor-to-parameter matrix): rows = parameters, columns =
     * predictors (attributes). Cells with no corresponding PPCell are null.
     */
    protected PPCell[][] m_ppMatrix;

    /**
     * Simple serializable holder for a single entry of the ParamMatrix
     * (parameter matrix).
     */
    static class PCell implements Serializable {

        /** For serialization */
        // CHECK ME WITH serialver
        private static final long serialVersionUID = 6502780192411755341L;

        /**
         * The target category. May be null for a numeric target, and may also
         * be null if this coefficient applies to all target categories.
         */
        protected String m_targetCategory;

        /** Name of the parameter this coefficient belongs to. */
        protected String m_parameterName;

        /** The coefficient. */
        protected double m_beta;

        /** Optional degrees of freedom; -1 when not supplied. */
        protected int m_df = -1;
    }

    /**
     * ParamMatrix: rows = target categories (only one if target is numeric),
     * columns = parameters (in order that they occur in the parameter list).
     */
    protected PCell[][] m_paramMatrix;

    /**
     * Constructs a GeneralRegression classifier.
     *
     * Reads the model-level attributes (model type, link functions,
     * distribution, offsets, ...) from the supplied element and then reads the
     * parameter list, the factor and covariate lists, the PPMatrix and the
     * ParamMatrix, in that order.
     *
     * @param model          the Element that holds the model definition
     * @param dataDictionary the data dictionary as a set of Instances
     * @param miningSchema   the mining schema
     * @throws Exception if there is a problem constructing the general regression
     *                   object from the PMML.
     */
    public GeneralRegression(Element model, Instances dataDictionary, MiningSchema miningSchema) throws Exception {

        super(dataDictionary, miningSchema);

        // get the model type
        String mType = model.getAttribute("modelType");
        boolean found = false;
        for (ModelType m : ModelType.values()) {
            if (m.toString().equals(mType)) {
                m_modelType = m;
                found = true;
                break;
            }
        }
        if (!found) {
            throw new Exception("[GeneralRegression] unknown model type: " + mType);
        }

        if (m_modelType == ModelType.ORDINALMULTINOMIAL) {
            // get the cumulative link function
            String cLink = model.getAttribute("cumulativeLink");
            found = false;
            for (CumulativeLinkFunction c : CumulativeLinkFunction.values()) {
                if (c.toString().equals(cLink)) {
                    m_cumulativeLinkFunction = c;
                    found = true;
                    break;
                }
            }
            if (!found) {
                throw new Exception("[GeneralRegression] unknown cumulative link function " + cLink);
            }
        } else if (m_modelType == ModelType.GENERALIZEDLINEAR || m_modelType == ModelType.GENERALLINEAR) {
            // get the link function
            String link = model.getAttribute("linkFunction");
            found = false;
            for (LinkFunction l : LinkFunction.values()) {
                if (l.toString().equals(link)) {
                    m_linkFunction = l;
                    found = true;
                    break;
                }
            }
            if (!found) {
                throw new Exception("[GeneralRegression] unknown link function " + link);
            }

            // get the link parameter
            m_linkParameter = parseOptionalDoubleAttribute(model, "linkParameter", m_linkParameter, "link parameter");

            // get the trials variable
            String trials = model.getAttribute("trialsVariable");
            if (trials != null && trials.length() > 0) {
                m_trialsVariable = trials;
            }

            // get the trials value
            m_trialsValue = parseOptionalDoubleAttribute(model, "trialsValue", m_trialsValue, "trials value");
        }

        String mName = model.getAttribute("modelName");
        if (mName != null && mName.length() > 0) {
            m_modelName = mName;
        }

        // anything other than "classification" leaves the default function type
        String fName = model.getAttribute("functionName");
        if ("classification".equals(fName)) {
            m_functionType = Regression.RegressionTable.CLASSIFICATION;
        }

        String algName = model.getAttribute("algorithmName");
        if (algName != null && algName.length() > 0) {
            m_algorithmName = algName;
        }

        String distribution = model.getAttribute("distribution");
        if (distribution != null && distribution.length() > 0) {
            found = false;
            for (Distribution d : Distribution.values()) {
                if (d.toString().equals(distribution)) {
                    m_distribution = d;
                    found = true;
                    break;
                }
            }
            if (!found) {
                throw new Exception("[GeneralRegression] unknown distribution type " + distribution);
            }
        }

        m_distParameter = parseOptionalDoubleAttribute(model, "distParameter", m_distParameter, "distribution parameter");

        String offsetV = model.getAttribute("offsetVariable");
        if (offsetV != null && offsetV.length() > 0) {
            m_offsetVariable = offsetV;
        }

        m_offsetValue = parseOptionalDoubleAttribute(model, "offsetValue", m_offsetValue, "offset value");

        // get the parameter list
        readParameterList(model);

        // get the factors and covariates
        readFactorsAndCovariates(model, "FactorList");
        readFactorsAndCovariates(model, "CovariateList");

        // read the PPMatrix
        readPPMatrix(model);

        // read the parameter estimates
        readParamMatrix(model);
    }

    /**
     * Reads an optional attribute of the model element as a double.
     *
     * @param model         the Element that holds the model definition
     * @param attributeName the name of the attribute to read
     * @param defaultValue  the value to return when the attribute is absent or empty
     * @param description   human readable name of the quantity, used in the error message
     * @return the parsed attribute value, or defaultValue if the attribute is
     *         absent or empty
     * @throws Exception if the attribute is present but is not a valid double;
     *                   the NumberFormatException is attached as the cause
     */
    private static double parseOptionalDoubleAttribute(Element model, String attributeName, double defaultValue, String description) throws Exception {
        String raw = model.getAttribute(attributeName);
        if (raw == null || raw.length() == 0) {
            return defaultValue;
        }
        try {
            return Double.parseDouble(raw);
        } catch (NumberFormatException ex) {
            throw new Exception("[GeneralRegression] unable to parse the " + description, ex);
        }
    }

    /**
     * Read the list of parameters. Every Parameter element found beneath the
     * single ParameterList element is appended to the parameter list, keeping
     * its name and (when non-empty) its label.
     *
     * @param model the Element that contains the model
     * @throws Exception if the model does not contain exactly one
     *                   ParameterList element.
     */
    protected void readParameterList(Element model) throws Exception {
        NodeList paramL = model.getElementsByTagName("ParameterList");

        // there should be exactly one parameter list; report a missing list
        // and a duplicated list separately so the message matches the problem
        int numLists = paramL.getLength();
        if (numLists == 0) {
            throw new Exception("[GeneralRegression] no parameter list found!");
        }
        if (numLists > 1) {
            throw new Exception("[GeneralRegression] more than one parameter list!");
        }

        Node paramN = paramL.item(0);
        if (paramN.getNodeType() != Node.ELEMENT_NODE) {
            return;
        }

        NodeList parameterList = ((Element) paramN).getElementsByTagName("Parameter");
        for (int i = 0; i < parameterList.getLength(); i++) {
            Node parameter = parameterList.item(i);
            if (parameter.getNodeType() != Node.ELEMENT_NODE) {
                continue;
            }
            Element parameterE = (Element) parameter;
            Parameter p = new Parameter();
            p.m_name = parameterE.getAttribute("name");
            // the label is optional; leave it null when absent or empty
            String label = parameterE.getAttribute("label");
            if (label != null && label.length() > 0) {
                p.m_label = label;
            }
            m_parameterList.add(p);
        }
    }

    /**
     * Read one of the two predictor lists (factors or covariates). Each
     * Predictor element of the list is matched by name against the attributes
     * of the mining schema and stored, together with the index of the matching
     * attribute, in the factor list or the covariate list. A model without the
     * requested list is left untouched.
     *
     * @param model             the Element that contains the model
     * @param factorOrCovariate holds the String "FactorList" or "CovariateList"
     * @throws Exception if there is a factor or covariate listed that isn't in the
     *                   mining schema, or if the list element occurs more than
     *                   once
     */
    protected void readFactorsAndCovariates(Element model, String factorOrCovariate) throws Exception {
        final Instances schemaFields = m_miningSchema.getFieldsAsInstances();

        final NodeList listNodes = model.getElementsByTagName(factorOrCovariate);
        final int listCount = listNodes.getLength();
        if (listCount > 1) {
            throw new Exception("[GeneralRegression] more than one " + factorOrCovariate + "! ");
        }
        if (listCount != 1) {
            // the list is optional: nothing to read
            return;
        }

        final Node listNode = listNodes.item(0);
        if (listNode.getNodeType() != Node.ELEMENT_NODE) {
            return;
        }

        final NodeList predictorNodes = ((Element) listNode).getElementsByTagName("Predictor");
        for (int n = 0; n < predictorNodes.getLength(); n++) {
            final Node predictorNode = predictorNodes.item(n);
            if (predictorNode.getNodeType() != Node.ELEMENT_NODE) {
                continue;
            }

            final Predictor predictor = new Predictor();
            predictor.m_name = ((Element) predictorNode).getAttribute("name");

            // locate the first mining schema attribute carrying this name
            int schemaIndex = 0;
            while (schemaIndex < schemaFields.numAttributes()
                    && !schemaFields.attribute(schemaIndex).name().equals(predictor.m_name)) {
                schemaIndex++;
            }
            if (schemaIndex >= schemaFields.numAttributes()) {
                throw new Exception("[GeneralRegression] reading factors and covariates - unable to find predictor " + predictor.m_name + " in the mining schema");
            }
            predictor.m_miningSchemaIndex = schemaIndex;

            if (factorOrCovariate.equals("FactorList")) {
                m_factorList.add(predictor);
            } else {
                m_covariateList.add(predictor);
            }
        }
    }

    /**
     * Read the PPMatrix from the xml. Does not handle multiple PPMatrixes yet.
     *
     * Each PPCell is stored at [index of its parameter in the parameter
     * list][mining schema index of its predictor]. For a covariate the cell
     * value is parsed as a double. For a factor it is parsed as a double when
     * the mining schema attribute is numeric, and otherwise looked up as the
     * index of a value of the attribute.
     *
     * @param model the Element that contains the model
     * @throws Exception if the model does not contain exactly one PPMatrix,
     *                   if a cell refers to an unknown parameter or predictor,
     *                   or if there is a problem parsing cell values.
     */
    protected void readPPMatrix(Element model) throws Exception {
        Instances miningSchemaI = m_miningSchema.getFieldsAsInstances();

        NodeList matrixL = model.getElementsByTagName("PPMatrix");

        // should be exactly one PPMatrix; report a missing matrix and a
        // duplicated matrix separately so the message matches the problem
        int numMatrices = matrixL.getLength();
        if (numMatrices == 0) {
            throw new Exception("[GeneralRegression] no PPMatrix found!");
        }
        if (numMatrices > 1) {
            throw new Exception("[GeneralRegression] more than one PPMatrix!");
        }

        // allocate space for the matrix
        // column that corresponds to the class will be empty (and will be missed out
        // when printing the model).
        m_ppMatrix = new PPCell[m_parameterList.size()][miningSchemaI.numAttributes()];

        Node ppM = matrixL.item(0);
        if (ppM.getNodeType() != Node.ELEMENT_NODE) {
            return;
        }

        NodeList cellL = ((Element) ppM).getElementsByTagName("PPCell");
        for (int i = 0; i < cellL.getLength(); i++) {
            Node cell = cellL.item(i);
            if (cell.getNodeType() != Node.ELEMENT_NODE) {
                continue;
            }
            Element cellE = (Element) cell;
            String predictorName = cellE.getAttribute("predictorName");
            String parameterName = cellE.getAttribute("parameterName");
            String value = cellE.getAttribute("value");

            // row: position of the parameter in the parameter list
            int parameterIndex = -1;
            for (int j = 0; j < m_parameterList.size(); j++) {
                if (m_parameterList.get(j).m_name.equals(parameterName)) {
                    parameterIndex = j;
                    break;
                }
            }
            if (parameterIndex == -1) {
                throw new Exception("[GeneralRegression] unable to find parameter name " + parameterName + " in parameter list");
            }

            double expOrIndex = -1;
            Predictor p = getCovariate(predictorName);
            if (p != null) {
                // covariate: the value is an exponent
                try {
                    expOrIndex = Double.parseDouble(value);
                } catch (NumberFormatException ex) {
                    throw new Exception("[GeneralRegression] unable to parse PPCell value: " + value, ex);
                }
            } else {
                // try as a factor
                p = getFactor(predictorName);
                if (p == null) {
                    throw new Exception("[GeneralRegression] can't find predictor " + predictorName + " in either the factors list or the covariates list");
                }

                Attribute att = miningSchemaI.attribute(p.m_miningSchemaIndex);
                // An example pmml file from DMG seems to suggest that it
                // is possible for a continuous variable in the mining schema
                // to be treated as a factor, so we have to check for this
                if (att.isNumeric()) {
                    // parse this value as a double. It will be treated as a value
                    // to match rather than an exponent since we are dealing with
                    // a factor here
                    try {
                        expOrIndex = Double.parseDouble(value);
                    } catch (NumberFormatException ex) {
                        throw new Exception("[GeneralRegression] unable to parse PPCell value: " + value, ex);
                    }
                } else {
                    // it is a nominal attribute in the mining schema so find
                    // the index that corresponds to this value
                    expOrIndex = att.indexOfValue(value);
                    if (expOrIndex == -1) {
                        throw new Exception("[GeneralRegression] unable to find PPCell value " + value + " in mining schema attribute " + att.name());
                    }
                }
            }

            // column: position of the predictor in the mining schema
            int predictorIndex = p.m_miningSchemaIndex;

            // fill in cell value
            PPCell ppc = new PPCell();
            ppc.m_predictorName = predictorName;
            ppc.m_parameterName = parameterName;
            ppc.m_value = expOrIndex;

            // TO-DO: ppc.m_targetCategory (when handling for multiple PPMatrixes is
            // implemented)
            m_ppMatrix[parameterIndex][predictorIndex] = ppc;
        }
    }

    /**
     * Finds the first entry of the covariate list with the given name.
     *
     * @param predictorName the name to look for (must not be null)
     * @return the matching covariate, or null if there is none
     */
    private Predictor getCovariate(String predictorName) {
        for (Predictor covariate : m_covariateList) {
            if (predictorName.equals(covariate.m_name)) {
                return covariate;
            }
        }
        return null;
    }

    /**
     * Finds the first entry of the factor list with the given name.
     *
     * @param predictorName the name to look for (must not be null)
     * @return the matching factor, or null if there is none
     */
    private Predictor getFactor(String predictorName) {
        for (Predictor factor : m_factorList) {
            if (predictorName.equals(factor.m_name)) {
                return factor;
            }
        }
        return null;
    }

    /**
     * Read the parameter matrix from the xml.
     * 
     * @param model Element that holds the model
     * @throws Exception if a problem is encountered during extraction of the
     *                   parameter matrix
     */
    private void readParamMatrix(Element model) throws Exception {

        Instances miningSchemaI = m_miningSchema.getFieldsAsInstances();
        Attribute classAtt = miningSchemaI.classAttribute();
        // used when function type is classification but class attribute is numeric
        // in the mining schema. We will assume that there is a Target specified in
        // the pmml that defines the legal values for this class.
        ArrayList<String> targetVals = null;

        NodeList matrixL = model.getElementsByTagName("ParamMatrix");
        if (matrixL.getLength() != 1) {
            throw new Exception("[GeneralRegression] more than one ParamMatrix!");
        }
        Element matrix = (Element) matrixL.item(0);

        // check for the case where the class in the mining schema is numeric,
        // but this attribute is treated as discrete
        if (m_functionType == Regression.RegressionTable.CLASSIFICATION && classAtt.isNumeric()) {
            // try and convert the class attribute to nominal. For this to succeed
            // there has to be a Target element defined in the PMML.
            if (!m_miningSchema.hasTargetMetaData()) {
                throw new Exception("[GeneralRegression] function type is classification and " + "class attribute in mining schema is numeric, however, " + "there is no Target element " + "specifying legal discrete values for the target!");

            }

            if (m_miningSchema.getTargetMetaData().getOptype() != TargetMetaInfo.Optype.CATEGORICAL) {
                throw new Exception("[GeneralRegression] function type is classification and " + "class attribute in mining schema is numeric, however " + "Target element in PMML does not have optype categorical!");
            }

            // OK now get legal values
            targetVals = m_miningSchema.getTargetMetaData().getValues();
            if (targetVals.size() == 0) {
                throw new Exception("[GeneralRegression] function type is classification and " + "class attribute in mining schema is numeric, however " + "Target element in PMML does not have any discrete values " + "defined!");
            }

            // Finally, convert the class in the mining schema to nominal
            m_miningSchema.convertNumericAttToNominal(miningSchemaI.classIndex(), targetVals);
        }

        // allocate space for the matrix
        m_paramMatrix = new PCell[(classAtt.isNumeric()) ? 1 : classAtt.numValues()][m_parameterList.size()];

        NodeList pcellL = matrix.getElementsByTagName("PCell");
        for (int i = 0; i < pcellL.getLength(); i++) {
            // indicates that that this beta applies to all target categories
            // or target is numeric
            int targetCategoryIndex = -1;
            int parameterIndex = -1;
            Node pcell = pcellL.item(i);
            if (pcell.getNodeType() == Node.ELEMENT_NODE) {
                String paramName = ((Element) pcell).getAttribute("parameterName");
                String targetCatName = ((Element) pcell).getAttribute("targetCategory");
                String coefficient = ((Element) pcell).getAttribute("beta");
                String df = ((Element) pcell).getAttribute("df");

                for (int j = 0; j < m_parameterList.size(); j++) {
                    if (m_parameterList.get(j).m_name.equals(paramName)) {
                        parameterIndex = j;
                        // use the label if defined
                        if (m_parameterList.get(j).m_label != null) {
                            paramName = m_parameterList.get(j).m_label;
                        }
                        break;
                    }
                }
                if (parameterIndex == -1) {
                    throw new Exception("[GeneralRegression] unable to find parameter name " + paramName + " in parameter list");
                }

                if (targetCatName != null && targetCatName.length() > 0) {
                    if (classAtt.isNominal() || classAtt.isString()) {
                        targetCategoryIndex = classAtt.indexOfValue(targetCatName);
                    } else {
                        throw new Exception("[GeneralRegression] found a PCell with a named " + "target category: " + targetCatName + " but class attribute is numeric in " + "mining schema");
                    }
                }

                PCell p = new PCell();
                if (targetCategoryIndex != -1) {
                    p.m_targetCategory = targetCatName;
                }
                p.m_parameterName = paramName;
                try {
                    p.m_beta = Double.parseDouble(coefficient);
                } catch (IllegalArgumentException ex) {
                    throw new Exception("[GeneralRegression] unable to parse beta value " + coefficient + " as a double from PCell");
                }
                if (df != null && df.length() > 0) {
                    try {
                        p.m_df = Integer.parseInt(df);
                    } catch (IllegalArgumentException ex) {
                        throw new Exception("[GeneralRegression] unable to parse df value " + df + " as an int from PCell");
                    }
                }

                if (targetCategoryIndex != -1) {
                    m_paramMatrix[targetCategoryIndex][parameterIndex] = p;
                } else {
                    // apply this PCell to all target categories (covers numeric class, in
                    // which case there will be only one row in the matrix anyway)
                    for (int j = 0; j < m_paramMatrix.length; j++) {
                        m_paramMatrix[j][parameterIndex] = p;
                    }
                }
            }
        }
    }

    /**
     * Builds a multi-line, human-readable summary of this model: PMML version,
     * creator application (unless it is "?"), model type, mining schema, factor
     * and covariate lists, the predictor-to-parameter matrix, the parameter
     * estimates and finally the link / cumulative link function settings.
     * 
     * @return the textual description of this general regression
     */
    public String toString() {
        StringBuffer out = new StringBuffer();

        // header section
        out.append("PMML version ").append(getPMMLVersion());
        if (!getCreatorApplication().equals("?")) {
            out.append("\nApplication: ").append(getCreatorApplication());
        }
        out.append("\nPMML Model: ").append(m_modelType);
        out.append("\n\n");
        out.append(m_miningSchema);

        // predictors, split into factors and covariates
        if (!m_factorList.isEmpty()) {
            out.append("Factors:\n");
            for (Predictor factor : m_factorList) {
                out.append("\t").append(factor).append("\n");
            }
        }
        out.append("\n");
        if (!m_covariateList.isEmpty()) {
            out.append("Covariates:\n");
            for (Predictor covariate : m_covariateList) {
                out.append("\t").append(covariate).append("\n");
            }
        }
        out.append("\n");

        // the two matrices
        printPPMatrix(out);
        out.append("\n");
        printParameterMatrix(out);
        out.append("\n");

        // link function details (only when a link function is set)
        if (m_linkFunction != LinkFunction.NONE) {
            out.append("Link function: ").append(m_linkFunction);

            // a named offset variable takes precedence over a fixed offset value
            if (m_offsetVariable != null) {
                out.append("\n\tOffset variable ").append(m_offsetVariable);
            } else if (!Double.isNaN(m_offsetValue)) {
                out.append("\n\tOffset value ").append(m_offsetValue);
            }

            // same precedence rule for trials
            if (m_trialsVariable != null) {
                out.append("\n\tTrials variable ").append(m_trialsVariable);
            } else if (!Double.isNaN(m_trialsValue)) {
                out.append("\n\tTrials value ").append(m_trialsValue);
            }

            if (m_distribution != Distribution.NONE) {
                out.append("\nDistribution: ").append(m_distribution);
            }

            boolean negBinomial = m_linkFunction == LinkFunction.NEGBIN && m_distribution == Distribution.NEGBINOMIAL;
            if (negBinomial && !Double.isNaN(m_distParameter)) {
                out.append("\n\tDistribution parameter ").append(m_distParameter);
            }

            boolean powerLink = m_linkFunction == LinkFunction.POWER || m_linkFunction == LinkFunction.ODDSPOWER;
            if (powerLink && !Double.isNaN(m_linkParameter)) {
                out.append("\n\nLink parameter ").append(m_linkParameter);
            }
        }

        // cumulative link function details (only when one is set)
        if (m_cumulativeLinkFunction != CumulativeLinkFunction.NONE) {
            out.append("Cumulative link function: ").append(m_cumulativeLinkFunction);

            if (m_offsetVariable != null) {
                out.append("\n\tOffset variable ").append(m_offsetVariable);
            } else if (!Double.isNaN(m_offsetValue)) {
                out.append("\n\tOffset value ").append(m_offsetValue);
            }
        }
        out.append("\n");

        return out.toString();
    }

    /**
     * Appends a formatted table of the predictor-to-parameter matrix to the
     * supplied buffer. Rows are parameters (shown by label when one is set,
     * otherwise by name); columns are the mining schema attributes other than
     * the class. A cell shows the nominal/string value label or the numeric
     * value; empty cells are rendered as blanks.
     * 
     * @param buff the StringBuffer to append to
     */
    protected void printPPMatrix(StringBuffer buff) {
        final Instances schema = m_miningSchema.getFieldsAsInstances();
        final int numAtts = schema.numAttributes();
        final int numParams = m_parameterList.size();
        final int classIdx = schema.classIndex();

        // column width starts out as the longest attribute name ...
        int colWidth = 0;
        for (int col = 0; col < numAtts; col++) {
            colWidth = Math.max(colWidth, schema.attribute(col).name().length());
        }

        // ... and is widened to fit the widest cell content
        for (int row = 0; row < numParams; row++) {
            for (int col = 0; col < numAtts; col++) {
                PPCell cell = m_ppMatrix[row][col];
                if (cell == null) {
                    continue;
                }

                // rough digit count of the integer part via log base 10
                double digits = Math.log(Math.abs(cell.m_value)) / Math.log(10.0);
                if (digits < 0) {
                    digits = 1;
                }
                // allowance of two extra characters on top of the digit count
                digits += 2.0;
                colWidth = Math.max(colWidth, (int) digits);

                Attribute att = schema.attribute(col);
                if (att.isNominal() || att.isString()) {
                    // the cell value is an index into the attribute's labels
                    String label = att.value((int) cell.m_value) + " ";
                    colWidth = Math.max(colWidth, label.length());
                }
            }
        }

        // width of the leading parameter column
        int paramWidth = "Parameter  ".length();
        for (Parameter param : m_parameterList) {
            String shown = ((param.m_label != null) ? param.m_label : param.m_name) + " ";
            paramWidth = Math.max(paramWidth, shown.length());
        }

        // title and the two header lines
        buff.append("Predictor-to-Parameter matrix:\n");
        buff.append(PMMLUtils.pad("Predictor", " ", (paramWidth + (colWidth * 2 + 2)) - "Predictor".length(), true));
        buff.append("\n" + PMMLUtils.pad("Parameter", " ", paramWidth - "Parameter".length(), false));
        for (int col = 0; col < numAtts; col++) {
            if (col == classIdx) {
                continue;
            }
            String attName = schema.attribute(col).name();
            buff.append(PMMLUtils.pad(attName, " ", colWidth + 1 - attName.length(), true));
        }
        buff.append("\n");

        // one line per parameter
        for (int row = 0; row < numParams; row++) {
            Parameter param = m_parameterList.get(row);
            String rowLabel = (param.m_label != null) ? param.m_label : param.m_name;
            buff.append(PMMLUtils.pad(rowLabel, " ", paramWidth - rowLabel.length(), false));

            for (int col = 0; col < numAtts; col++) {
                if (col == classIdx) {
                    continue;
                }
                PPCell cell = m_ppMatrix[row][col];
                String text;
                if (cell == null) {
                    text = " ";
                } else {
                    Attribute att = schema.attribute(col);
                    if (att.isNominal() || att.isString()) {
                        text = att.value((int) cell.m_value);
                    } else {
                        text = Utils.doubleToString(cell.m_value, colWidth, 4).trim();
                    }
                }
                buff.append(PMMLUtils.pad(text, " ", colWidth + 1 - text.length(), true));
            }
            buff.append("\n");
        }
    }

    /**
     * Appends a formatted table of the parameter estimates to the supplied
     * buffer. For every row of the parameter matrix that holds at least one
     * cell, the class value is printed (a blank for a class that is neither
     * nominal nor string) followed by one line per parameter giving its name,
     * coefficient and degrees of freedom.
     * 
     * @param buff the StringBuffer to append to
     */
    protected void printParameterMatrix(StringBuffer buff) {
        final Attribute classAtt = m_miningSchema.getFieldsAsInstances().classAttribute();
        final String className = classAtt.name();
        final boolean discreteClass = classAtt.isNominal() || classAtt.isString();
        final int numParams = m_parameterList.size();

        // class column: wide enough for the class name and every class label
        int classWidth = className.length();
        if (discreteClass) {
            for (int v = 0; v < classAtt.numValues(); v++) {
                classWidth = Math.max(classWidth, classAtt.value(v).length());
            }
        }

        // parameter column: widest label (or name) plus a trailing space
        int paramWidth = 0;
        for (Parameter param : m_parameterList) {
            String shown = ((param.m_label != null) ? param.m_label : param.m_name) + " ";
            paramWidth = Math.max(paramWidth, shown.length());
        }

        // coefficient column: at least as wide as its heading
        int betaWidth = "Coeff.".length();
        for (PCell[] matrixRow : m_paramMatrix) {
            for (int col = 0; col < numParams; col++) {
                PCell cell = matrixRow[col];
                if (cell == null) {
                    continue;
                }
                // rough digit count of the integer part via log base 10
                double digits = Math.log(Math.abs(cell.m_beta)) / Math.log(10);
                if (digits < 0) {
                    digits = 1;
                }
                // allowance of seven extra characters on top of the digit count
                digits += 7.0;
                betaWidth = Math.max(betaWidth, (int) digits);
            }
        }

        // title and header line
        buff.append("Parameter estimates:\n");
        buff.append(PMMLUtils.pad(className, " ", classWidth + paramWidth + 2 - className.length(), false));
        buff.append(PMMLUtils.pad("Coeff.", " ", betaWidth + 1 - "Coeff.".length(), true));
        buff.append(PMMLUtils.pad("df", " ", betaWidth - "df".length(), true));
        buff.append("\n");

        for (int row = 0; row < m_paramMatrix.length; row++) {
            // rows without any cell are not printed at all
            boolean hasCell = false;
            for (int col = 0; col < numParams && !hasCell; col++) {
                hasCell = m_paramMatrix[row][col] != null;
            }
            if (!hasCell) {
                continue;
            }

            // class value heading for this group of parameters
            String classLabel = discreteClass ? classAtt.value(row) : " ";
            buff.append(PMMLUtils.pad(classLabel, " ", classWidth - classLabel.length(), false));
            buff.append("\n");

            for (int col = 0; col < numParams; col++) {
                PCell cell = m_paramMatrix[row][col];
                if (cell == null) {
                    continue;
                }
                String name = cell.m_parameterName;
                buff.append(PMMLUtils.pad(name, " ", classWidth + paramWidth + 2 - name.length(), true));
                String betaText = Utils.doubleToString(cell.m_beta, betaWidth, 4).trim();
                buff.append(PMMLUtils.pad(betaText, " ", betaWidth + 1 - betaText.length(), true));
                String dfText = Utils.doubleToString(cell.m_df, betaWidth, 4).trim();
                buff.append(PMMLUtils.pad(dfText, " ", betaWidth - dfText.length(), true));
                buff.append("\n");
            }
        }
    }

    /**
     * Maps the values of an incoming instance onto the model's parameters. Each
     * entry of the result is the product of the contributions of all
     * predictor-to-parameter cells in that parameter's row: a factor cell
     * contributes 1 when the instance value equals the cell value (compared as
     * ints) and 0 otherwise; a covariate cell contributes the instance value
     * raised to the power of the cell value. A row without any cell yields 1.0,
     * which is how the intercept is represented.
     * 
     * @param incomingInst the values of the incoming test instance
     * @return the populated parameter vector ready to be multiplied against the
     *         vector of coefficients.
     * @throws Exception if a cell names a predictor that is neither a factor
     *                   nor a covariate
     */
    private double[] incomingParamVector(double[] incomingInst) throws Exception {
        final int numAtts = m_miningSchema.getFieldsAsInstances().numAttributes();
        final int numParams = m_parameterList.size();
        final double[] paramVector = new double[numParams];

        for (int row = 0; row < numParams; row++) {
            // start from the multiplicative identity; stays 1.0 for an
            // all-null row (the intercept)
            paramVector[row] = 1.0;

            for (int col = 0; col < numAtts; col++) {
                PPCell cell = m_ppMatrix[row][col];
                if (cell == null) {
                    continue;
                }

                // factors are looked up first, then covariates
                Predictor factor = getFactor(cell.m_predictorName);
                if (factor != null) {
                    boolean matches = (int) incomingInst[factor.m_miningSchemaIndex] == (int) cell.m_value;
                    if (!matches) {
                        paramVector[row] *= 0.0;
                    }
                    continue;
                }

                Predictor covariate = getCovariate(cell.m_predictorName);
                if (covariate == null) {
                    throw new Exception("[GeneralRegression] can't find predictor " + cell.m_predictorName + " in either the list of factors or covariates");
                }
                paramVector[row] *= Math.pow(incomingInst[covariate.m_miningSchemaIndex], cell.m_value);
            }
        }

        return paramVector;
    }

    /**
     * Computes the prediction for the given test instance. The instance has to
     * belong to a dataset when it's being classified, because that dataset is
     * mapped to the mining schema on first use.
     * <p>
     * If any non-class input is missing after conversion to the mining schema,
     * the model is not evaluated: with target meta data available the default
     * value (numeric class) or the prior probabilities (otherwise) are returned;
     * without it a warning is logged and a missing value (numeric class) or an
     * all-zero array is returned.
     * 
     * @param inst the instance to be classified
     * @return a single-element array holding the prediction for a numeric
     *         class, otherwise an array with one entry per class value
     * @exception Exception if an error occurred during the prediction
     */
    public double[] distributionForInstance(Instance inst) throws Exception {
        if (!m_initialized) {
            mapToMiningSchema(inst.dataset());
        }

        final Instances schema = m_miningSchema.getFieldsAsInstances();
        final Attribute classAtt = schema.classAttribute();
        final boolean numericClass = classAtt.isNumeric();
        final double[] preds = new double[numericClass ? 1 : classAtt.numValues()];

        // values of the incoming instance, ordered like the fields of the
        // mining schema
        double[] incoming = m_fieldsMap.instanceToSchema(inst, m_miningSchema);

        // look for a missing (NaN) input among the non-class fields
        final int classIdx = schema.classIndex();
        boolean hasMissing = false;
        for (int i = 0; i < incoming.length && !hasMissing; i++) {
            hasMissing = (i != classIdx) && Double.isNaN(incoming[i]);
        }

        if (!hasMissing) {
            // normal path: evaluate the model
            computeResponses(incoming, incomingParamVector(incoming), preds);
            return preds;
        }

        if (m_miningSchema.hasTargetMetaData()) {
            // fall back to the default value / prior probabilities
            TargetMetaInfo targetData = m_miningSchema.getTargetMetaData();
            if (numericClass) {
                preds[0] = targetData.getDefaultValue();
            } else {
                for (int v = 0; v < classAtt.numValues(); v++) {
                    preds[v] = targetData.getPriorProbability(classAtt.value(v));
                }
            }
            return preds;
        }

        // nothing to fall back to: warn and return "no prediction"
        String outcome = (classAtt.isNominal() || classAtt.isString()) ? "zero probabilities output)." : "NaN output).";
        String message = "[GeneralRegression] WARNING: Instance to predict has missing value(s) but "
            + "there is no missing value handling meta data and no "
            + "prior probabilities/default value to fall back to. No "
            + "prediction will be made (" + outcome;
        if (m_log != null) {
            m_log.warn(message);
        } else {
            System.err.println(message);
        }

        if (numericClass) {
            preds[0] = Utils.missingValue();
        }
        return preds;
    }

    /**
     * Evaluates the model for the current instance. First the linear predictor
     * of every response is accumulated as the sum over all parameters of
     * parameter value times coefficient; then the result is transformed
     * according to the model type.
     * 
     * @param incomingInst        raw incoming instance values (after missing value
     *                            replacement and outlier treatment)
     * @param incomingParamVector incoming instance values mapped to parameters
     * @param responses           will contain the responses computed by the
     *                            function
     * @throws Exception if the model type requires a link (or cumulative link)
     *                   function and none is set, or if the model type is not
     *                   handled
     */
    private void computeResponses(double[] incomingInst, double[] incomingParamVector, double[] responses) throws Exception {
        final int numParams = m_parameterList.size();

        for (int row = 0; row < responses.length; row++) {
            for (int col = 0; col < numParams; col++) {
                // A missing cell is treated as a coefficient of zero. This
                // covers the last class in classification, whose row may have
                // no entries at all in the PMML file.
                PCell cell = m_paramMatrix[row][col];
                double beta = (cell == null) ? 0.0 : cell.m_beta;
                responses[row] += incomingParamVector[col] * beta;
            }
        }

        switch (m_modelType) {
        case REGRESSION:
            // the linear predictor is the response
            break;
        case MULTINOMIALLOGISTIC:
            computeProbabilitiesMultinomialLogistic(responses);
            break;
        case GENERALLINEAR:
        case GENERALIZEDLINEAR:
            if (m_linkFunction == LinkFunction.NONE) {
                throw new Exception("[GeneralRegression] no link function specified!");
            }
            computeResponseGeneralizedLinear(incomingInst, responses);
            break;
        case ORDINALMULTINOMIAL:
            if (m_cumulativeLinkFunction == CumulativeLinkFunction.NONE) {
                throw new Exception("[GeneralRegression] no cumulative link function specified!");
            }
            computeResponseOrdinalMultinomial(incomingInst, responses);
            break;
        default:
            throw new Exception("[GeneralRegression] unknown model type");
        }
    }

    /**
     * Converts the linear predictors of a multinomial logistic model into
     * probabilities, in place. Entry j becomes 1 / sum_k exp(r[k] - r[j]). If
     * some r[k] exceeds r[j] by more than 700 the entry is set to 0.0 instead
     * of evaluating the exponentials (presumably to keep Math.exp from
     * overflowing).
     * 
     * @param responses on entry the linear predictors, on exit the
     *                  probabilities computed by the function.
     */
    private static void computeProbabilitiesMultinomialLogistic(double[] responses) {
        // work from a snapshot, since responses is overwritten as we go
        final double[] scores = responses.clone();
        final int count = scores.length;

        for (int target = 0; target < count; target++) {
            double denominator = 0;
            int other = 0;
            // stop early as soon as a difference above 700 is seen; written as
            // !(x > 700) so that NaN differences keep accumulating
            while (other < count && !(scores[other] - scores[target] > 700)) {
                denominator += Math.exp(scores[other] - scores[target]);
                other++;
            }
            // leaving the loop before the end means the threshold was exceeded
            responses[target] = (other < count) ? 0.0 : 1.0 / denominator;
        }
    }

    /**
     * Applies the link function to the linear predictors of a general linear or
     * generalized linear model, in place. Offset and trials are each taken from
     * the named mining schema field when one is set, otherwise from the fixed
     * value when that is not NaN, otherwise they default to 0 and 1
     * respectively.
     * 
     * @param incomingInst the raw incoming instance values (after missing value
     *                     replacement and outlier treatment etc).
     * @param responses    on entry the linear predictors, on exit the responses
     *                     computed by the function
     * @throws Exception if the offset or trials variable is not in the mining
     *                   schema, or a required distribution or link parameter is
     *                   undefined.
     */
    private void computeResponseGeneralizedLinear(double[] incomingInst, double[] responses) throws Exception {
        // snapshot of the linear predictors
        final double[] linear = responses.clone();

        // offset: field value, else fixed value, else 0
        double offset;
        if (m_offsetVariable == null) {
            offset = Double.isNaN(m_offsetValue) ? 0 : m_offsetValue;
        } else {
            Attribute offsetAtt = m_miningSchema.getFieldsAsInstances().attribute(m_offsetVariable);
            if (offsetAtt == null) {
                throw new Exception("[GeneralRegression] unable to find offset variable " + m_offsetVariable + " in the mining schema!");
            }
            offset = incomingInst[offsetAtt.index()];
        }

        // trials: field value, else fixed value, else 1
        double trials;
        if (m_trialsVariable == null) {
            trials = Double.isNaN(m_trialsValue) ? 1 : m_trialsValue;
        } else {
            Attribute trialsAtt = m_miningSchema.getFieldsAsInstances().attribute(m_trialsVariable);
            if (trialsAtt == null) {
                throw new Exception("[GeneralRegression] unable to find trials variable " + m_trialsVariable + " in the mining schema!");
            }
            trials = incomingInst[trialsAtt.index()];
        }

        // the distribution parameter is only required for the negative
        // binomial link/distribution combination
        double distParam = 0;
        boolean needsDistParam = m_linkFunction == LinkFunction.NEGBIN && m_distribution == Distribution.NEGBINOMIAL;
        if (needsDistParam) {
            if (Double.isNaN(m_distParameter)) {
                throw new Exception("[GeneralRegression] no distribution parameter defined!");
            }
            distParam = m_distParameter;
        }

        // the link parameter is only required for the power-type links
        double linkParam = 0;
        boolean needsLinkParam = m_linkFunction == LinkFunction.POWER || m_linkFunction == LinkFunction.ODDSPOWER;
        if (needsLinkParam) {
            if (Double.isNaN(m_linkParameter)) {
                throw new Exception("[GeneralRegression] no link parameter defined!");
            }
            linkParam = m_linkParameter;
        }

        for (int i = 0; i < linear.length; i++) {
            responses[i] = m_linkFunction.eval(linear[i], offset, trials, distParam, linkParam);
        }
    }

    /**
     * Computes the category probabilities for the ordinal multinomial model
     * type, in place. The cumulative link function is evaluated on each linear
     * predictor to give a cumulative value F_i; the probability of category i
     * is the difference of successive cumulative values:
     * <ul>
     * <li>first category: F_0</li>
     * <li>middle categories: F_i - F_(i-1)</li>
     * <li>last category: 1 - F_(last-1)</li>
     * </ul>
     * The subtraction always uses the previous <em>cumulative</em> value, never
     * the previous category's probability, so that the probabilities of three
     * or more categories sum to one.
     * 
     * @param incomingInst the raw incoming instance values (after missing value
     *                     replacement and outlier treatment etc).
     * @param responses    on entry the linear predictors, on exit the responses
     *                     computed by the function
     * @throws Exception if the offset variable cannot be found in the mining
     *                   schema.
     */
    private void computeResponseOrdinalMultinomial(double[] incomingInst, double[] responses) throws Exception {

        // snapshot of the linear predictors, since responses is overwritten
        double[] r = responses.clone();

        // offset: field value, else fixed value (when not NaN), else 0
        double offset = 0;
        if (m_offsetVariable != null) {
            Attribute offsetAtt = m_miningSchema.getFieldsAsInstances().attribute(m_offsetVariable);
            if (offsetAtt == null) {
                throw new Exception("[GeneralRegression] unable to find offset variable " + m_offsetVariable + " in the mining schema!");
            }
            offset = incomingInst[offsetAtt.index()];
        } else if (!Double.isNaN(m_offsetValue)) {
            offset = m_offsetValue;
        }

        // cumulative value of the preceding category (0 before the first one)
        double previousCumulative = 0.0;
        final int last = r.length - 1;
        for (int i = 0; i < r.length; i++) {
            if (i > 0 && i == last) {
                // the last category takes whatever probability mass remains
                responses[i] = 1.0 - previousCumulative;
            } else {
                double cumulative = m_cumulativeLinkFunction.eval(r[i], offset);
                responses[i] = cumulative - previousCumulative;
                previousCumulative = cumulative;
            }
        }
    }

}
